import numpy as np
from collections import Counter
from .metrics import accuracy_score

def gini(y):
    """Gini impurity of the label sequence ``y``.

    Computed as ``1 - sum(p_c ** 2)`` over the class frequencies ``p_c``.
    Returns 0.0 for a pure set of labels; note that an empty ``y`` yields 1.
    """
    n = len(y)
    return 1 - sum((count / n) ** 2 for count in Counter(y).values())

def cut(X, y, d, v):
    """Partition the dataset on feature column ``d`` at threshold ``v``.

    Rows with ``X[:, d] <= v`` go left, rows with ``X[:, d] > v`` go right.
    Returns ``(X_left, X_right, y_left, y_right)``.
    """
    column = X[:, d]
    go_left = column <= v
    go_right = column > v
    return X[go_left], X[go_right], y[go_left], y[go_right]

def try_split(X, y, min_samples_leaf):
    """Exhaustively search for the best (dimension, threshold) split.

    Candidate thresholds are midpoints between consecutive distinct sorted
    values in each feature column. A split is scored by
    ``gini(y_left) + gini(y_right)`` (NOTE: this is an *unweighted* sum;
    standard CART weights each child by its sample fraction) and must leave
    at least ``min_samples_leaf`` samples on each side.

    Returns ``(best_g, best_d, best_v)``; ``best_d == -1`` when no valid
    split exists.
    """
    # BUG FIX: was initialized to 1, but the unweighted gini sum can approach
    # 2 for multi-class data, so every split scoring >= 1 was rejected and
    # splittable nodes were reported as unsplittable (d == -1).
    best_g = float('inf')
    best_d = -1
    best_v = -1
    for d in range(X.shape[1]):
        col = X[:, d]                       # hoist column lookup out of inner loop
        sorted_index = np.argsort(col)
        for i in range(len(X) - 1):
            lo = col[sorted_index[i]]
            hi = col[sorted_index[i + 1]]
            if lo == hi:                    # no boundary between equal values
                continue
            v = (lo + hi) / 2
            # (was unpacked into a misnamed, unused `C_right`)
            X_left, X_right, y_left, y_right = cut(X, y, d, v)
            if len(y_left) < min_samples_leaf or len(y_right) < min_samples_leaf:
                continue                    # skip before paying for gini()
            g_all = gini(y_left) + gini(y_right)
            if g_all < best_g:
                best_g = g_all
                best_d = d
                best_v = v
    return best_g, best_d, best_v

class DecisionTreeClassifier():
    """CART-style decision tree classifier using Gini impurity.

    The tree is grown greedily with `try_split` / `cut`; internal nodes hold
    a (dim, threshold) test, leaves hold a majority-vote label.

    Parameters
    ----------
    max_depth : int
        Maximum number of split levels (root is depth 1).
    min_samples_leaf : int
        Minimum number of samples required on each side of a split.
    """

    def __init__(self, max_depth=2, min_samples_leaf=1):
        self.tree_ = None  # root Node, set by fit()
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf

    def fit(self, X, y):
        """Grow the tree on training data (X, y) and return self."""
        self.tree_ = self.create_tree(X, y)
        if self.tree_ is None:
            # BUG FIX: if the root cannot be split (pure labels, depth limit,
            # or min_samples_leaf), the old code left tree_ = None and
            # predict() failed even after a successful fit(). Fall back to a
            # single majority-vote leaf so the model always predicts.
            self.tree_ = Node(l=Counter(y).most_common(1)[0][0])
        return self

    def predict(self, X):
        """Return the predicted label for each row of X."""
        # Was an `assert` (stripped under python -O); raise explicitly instead.
        if self.tree_ is None:
            raise RuntimeError('call fit() before predict()')
        return np.array([self._predict(x, self.tree_) for x in X])

    def _predict(self, x, node):
        # Walk down the tree until a leaf (label is not None) is reached.
        if node.label is not None:
            return node.label
        if x[node.dim] <= node.v:
            return self._predict(x, node.children_left)
        return self._predict(x, node.children_right)

    def create_tree(self, X, y, current_depth=1):
        """Recursively build a subtree for (X, y).

        Returns None when this subset should become a leaf (depth limit
        reached, labels already pure, or no valid split); the caller then
        materializes a majority-vote leaf Node.
        """
        if current_depth > self.max_depth:
            return None
        # BUG FIX: the old code tested `g == 0` AFTER the split search, which
        # also threw away *perfect* splits of impure data (a zero-impurity
        # split is the best possible outcome, not a stopping condition).
        # Check purity of the labels themselves instead.
        if len(Counter(y)) <= 1:
            return None
        g, d, v = try_split(X, y, self.min_samples_leaf)
        if d == -1:
            return None
        node = Node(d, v, g)
        X_left, X_right, y_left, y_right = cut(X, y, d, v)

        node.children_left = self.create_tree(X_left, y_left, current_depth + 1)
        if node.children_left is None:
            node.children_left = Node(l=Counter(y_left).most_common(1)[0][0])

        node.children_right = self.create_tree(X_right, y_right, current_depth + 1)
        if node.children_right is None:
            node.children_right = Node(l=Counter(y_right).most_common(1)[0][0])

        return node

    def score(self, X_test, y_test):
        """Mean accuracy of predict(X_test) against y_test."""
        return accuracy_score(y_test, self.predict(X_test))

class Node():
    """A single node of the decision tree.

    Internal nodes carry a split test (``dim``, ``v``) and its gini score
    ``g``; leaf nodes carry only a predicted ``label``. The two child links
    start as None and are filled in externally by the tree builder.
    """

    def __init__(self, d=None, v=None, g=None, l=None):
        # Split definition (internal nodes only).
        self.dim = d
        self.v = v
        self.g = g
        # Predicted class (leaf nodes only).
        self.label = l
        # Subtrees, attached after construction by the tree builder.
        self.children_left = None
        self.children_right = None

    def __repr__(self):
        return f'Node(d={self.dim},v={self.v},g={self.g},l={self.label})'
